{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#######################################################################################################\n",
    "# Summary\n",
    "# 1. PyTorch Multi-GPU example\n",
    "# 2. On-the-fly data-augmentation (random crop, random flip)\n",
    "# Slightly rewritten for 0.4.0+ API\n",
    "#######################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MULTI_GPU = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torchvision.models as models\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.nn.init as init\n",
    "import torchvision.transforms as transforms\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from sklearn.metrics.ranking import roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from PIL import Image\n",
    "from common.utils import download_data_chextxray, get_imgloc_labels, get_train_valid_test_split\n",
    "from common.utils import compute_roc_auc, get_cuda_version, get_cudnn_version, get_gpu_name\n",
    "from common.utils import yield_mb\n",
    "from common.params_dense import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) \n",
      "[GCC 7.2.0]\n",
      "PyTorch:  0.4.0\n",
      "Numpy:  1.14.1\n",
      "GPU:  ['Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB']\n",
      "CUDA Version 9.0.176\n",
      "CuDNN Version  7.0.5\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"PyTorch: \", torch.__version__)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPUs:  24\n",
      "GPUs:  4\n"
     ]
    }
   ],
   "source": [
    "CPU_COUNT = multiprocessing.cpu_count()\n",
    "GPU_COUNT = len(get_gpu_name())\n",
    "print(\"CPUs: \", CPU_COUNT)\n",
    "print(\"GPUs: \", GPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chestxray/images chestxray/Data_Entry_2017.csv\n"
     ]
    }
   ],
   "source": [
    "# Model-params\n",
    "IMAGENET_RGB_MEAN_TORCH = [0.485, 0.456, 0.406]\n",
    "IMAGENET_RGB_SD_TORCH = [0.229, 0.224, 0.225]\n",
    "# Paths\n",
    "CSV_DEST = \"chestxray\"\n",
    "IMAGE_FOLDER = os.path.join(CSV_DEST, \"images\")\n",
    "LABEL_FILE = os.path.join(CSV_DEST, \"Data_Entry_2017.csv\")\n",
    "print(IMAGE_FOLDER, LABEL_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually scale to multi-gpu\n",
    "assert torch.cuda.is_available()\n",
    "_DEVICE = torch.device(\"cuda:0\")\n",
    "# enables cudnn's auto-tuner\n",
    "torch.backends.cudnn.benchmark=True\n",
    "if MULTI_GPU:\n",
    "    LR *= GPU_COUNT \n",
    "    BATCHSIZE *= GPU_COUNT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please make sure to download\n",
      "https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\n",
      "Data already exists\n",
      "CPU times: user 590 ms, sys: 260 ms, total: 851 ms\n",
      "Wall time: 850 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Download data\n",
    "# Wall time: 17min 58s\n",
    "print(\"Please make sure to download\")\n",
    "print(\"https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\")\n",
    "download_data_chextxray(CSV_DEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise by imagenet mean/sd\n",
    "normalize = transforms.Normalize(IMAGENET_RGB_MEAN_TORCH,\n",
    "                                 IMAGENET_RGB_SD_TORCH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XrayData(Dataset):\n",
    "    def __init__(self, img_dir, lbl_file, patient_ids, transform=None):\n",
    "        \n",
    "        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)\n",
    "        self.transform = transform\n",
    "        print(\"Loaded {} labels and {} images\".format(len(self.labels), len(self.img_locs)))\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        im_file = self.img_locs[idx]\n",
    "        im_rgb = Image.open(im_file)\n",
    "        label = self.labels[idx]\n",
    "        if self.transform is not None:\n",
    "            im_rgb = self.transform(im_rgb)\n",
    "        return im_rgb, torch.FloatTensor(label)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.img_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_augmentation_dataset(img_dir, lbl_file, patient_ids, normalize):\n",
    "    dataset = XrayData(img_dir, lbl_file, patient_ids,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.Resize(WIDTH),\n",
    "                           transforms.ToTensor(),  \n",
    "                           normalize]))\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train:21563 valid:3080 test:6162\n"
     ]
    }
   ],
   "source": [
    "train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 87306 labels and 87306 images\n"
     ]
    }
   ],
   "source": [
    "# Dataset for training\n",
    "train_dataset = XrayData(img_dir=IMAGE_FOLDER,\n",
    "                         lbl_file=LABEL_FILE,\n",
    "                         patient_ids=train_set,\n",
    "                         transform=transforms.Compose([\n",
    "                             transforms.RandomResizedCrop(size=WIDTH),\n",
    "                             transforms.RandomHorizontalFlip(),\n",
    "                             transforms.ToTensor(),  # need to convert image to tensor!\n",
    "                             normalize]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 7616 labels and 7616 images\n",
      "Loaded 17198 labels and 17198 images\n"
     ]
    }
   ],
   "source": [
    "valid_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, valid_set, normalize)\n",
    "test_dataset = no_augmentation_dataset(IMAGE_FOLDER, LABEL_FILE, test_set, normalize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_symbol(out_features=CLASSES, multi_gpu=MULTI_GPU):\n",
    "    model = models.densenet.densenet121(pretrained=True)\n",
    "    # Replace classifier (FC-1000) with (FC-14)\n",
    "    model.classifier = nn.Sequential(\n",
    "        nn.Linear(model.classifier.in_features, out_features), \n",
    "        nn.Sigmoid())\n",
    "    if multi_gpu:\n",
    "        model = nn.DataParallel(model)\n",
    "    # CUDA\n",
    "    model.to(_DEVICE)  \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_symbol(sym, lr=LR):\n",
    "    # BCE Loss since classes not mutually exclusive + Sigmoid FC-layer\n",
    "    cri = nn.BCELoss()\n",
    "    opt = optim.Adam(sym.parameters(), lr=lr, betas=(0.9, 0.999))\n",
    "    sch = ReduceLROnPlateau(opt, factor=0.1, patience=5, mode='min')\n",
    "    return opt, cri, sch "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, dataloader, optimizer, criterion):\n",
    "    model.train()\n",
    "    print(\"Training epoch\")\n",
    "    loss_val = 0\n",
    "    for i, (data, target) in enumerate(dataloader): \n",
    "        # Get samples (both async)\n",
    "        data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)\n",
    "        # Forwards\n",
    "        output = model(data)\n",
    "        # Loss\n",
    "        loss = criterion(output, target)\n",
    "        # Back-prop\n",
    "        optimizer.zero_grad()\n",
    "        # Log the loss (before .backward())\n",
    "        loss_val += loss.item()\n",
    "        loss.backward()\n",
    "        optimizer.step()   \n",
    "    print(\"Training loss: {0:.4f}\".format(loss_val/i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_epoch(model, dataloader, criterion, phase='valid', cl=CLASSES):\n",
    "    model.eval()\n",
    "    if phase == 'testing':\n",
    "        print(\"Testing epoch\")\n",
    "    else:\n",
    "        print(\"Validating epoch\")\n",
    "        \n",
    "    # Don't save gradients\n",
    "    with torch.no_grad():\n",
    "        if phase == 'testing':\n",
    "            # pre-allocate predictions\n",
    "            len_pred = len(dataloader)*(dataloader.batch_size)\n",
    "            num_lab = dataloader.dataset.labels.shape[-1]\n",
    "            out_pred = torch.cuda.FloatTensor(len_pred, num_lab).fill_(0)\n",
    "        loss_val = 0\n",
    "        for i, (data, target) in enumerate(dataloader): \n",
    "            # Get samples\n",
    "            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)\n",
    "             # Forwards\n",
    "            output = model(data)\n",
    "            # Loss\n",
    "            loss = criterion(output, target)\n",
    "            # Log the loss\n",
    "            loss_val += loss.item()\n",
    "            # Log for AUC\n",
    "            if phase == 'testing':\n",
    "                out_pred[output.size(0)*i:output.size(0)*(1+i)] = output.data\n",
    "        # Fina loss\n",
    "        loss_mean = loss_val/i \n",
    "    \n",
    "    if phase == 'testing':\n",
    "        out_gt = dataloader.dataset.labels\n",
    "        out_pred = out_pred.cpu().numpy()[:len(out_gt)]  # Trim padding\n",
    "        print(\"Test-Dataset loss: {0:.4f}\".format(loss_mean))\n",
    "        print(\"Test-Dataset AUC: {0:.4f}\".format(compute_roc_auc(out_gt, out_pred, cl)))\n",
    "    else:\n",
    "        print(\"Validation loss: {0:.4f}\".format(loss_mean))\n",
    "    return loss_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimal to use fewer workers than CPU_COUNT\n",
    "# DataLoaders\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=True, num_workers=6, pin_memory=True)\n",
    "# Using a bigger batch-size (than BATCHSIZE) for below worsens performance\n",
    "valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCHSIZE,\n",
    "                          shuffle=False, num_workers=6, pin_memory=True)\n",
    "test_loader = DataLoader(dataset=test_dataset, batch_size=BATCHSIZE,\n",
    "                         shuffle=False, num_workers=6, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Train CheXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/torchvision-0.2.1-py3.5.egg/torchvision/models/densenet.py:212: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.54 s, sys: 1.69 s, total: 6.23 s\n",
      "Wall time: 6.66 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load symbol\n",
    "chexnet_sym = get_symbol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.25 ms, sys: 0 ns, total: 2.25 ms\n",
      "Wall time: 2.25 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load optimiser, loss\n",
    "# Scheduler for LRPlateau is not used\n",
    "optimizer, criterion, scheduler = init_symbol(chexnet_sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training epoch\n",
      "Training loss: 0.1748\n",
      "Validating epoch\n",
      "Validation loss: 0.1533\n",
      "Epoch time: 138 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1590\n",
      "Validating epoch\n",
      "Validation loss: 0.1496\n",
      "Epoch time: 112 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1564\n",
      "Validating epoch\n",
      "Validation loss: 0.1503\n",
      "Epoch time: 111 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1553\n",
      "Validating epoch\n",
      "Validation loss: 0.1467\n",
      "Epoch time: 110 seconds\n",
      "Training epoch\n",
      "Training loss: 0.1542\n",
      "Validating epoch\n",
      "Validation loss: 0.1480\n",
      "Epoch time: 114 seconds\n",
      "CPU times: user 13min 19s, sys: 2min 42s, total: 16min 1s\n",
      "Wall time: 9min 43s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Main training loop: 26min 50s\n",
    "# 4 GPU - Main training loop: 9min 43s\n",
    "# Main train/val loop\n",
    "for j in range(EPOCHS):\n",
    "    stime = time.time()\n",
    "    train_epoch(chexnet_sym, train_loader, optimizer, criterion)\n",
    "    loss_val = valid_epoch(chexnet_sym, valid_loader, criterion)   \n",
    "    print(\"Epoch time: {0:.0f} seconds\".format(time.time()-stime))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Test CheXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5 µs, sys: 1e+03 ns, total: 6 µs\n",
      "Wall time: 11.7 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load model for testing\n",
    "# I comment this out to create a fair test against Keras\n",
    "#chexnet_sym_test = get_symbol()\n",
    "#chkpt = torch.load(\"best_chexnet.pth.tar\")\n",
    "#chexnet_sym_test.load_state_dict(chkpt['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing epoch\n",
      "Test-Dataset loss: 0.1581\n",
      "Full AUC [0.8094505616256435, 0.8379730923933503, 0.7931988130183195, 0.8755269047391451, 0.8741294809854842, 0.8849574340745426, 0.73340398318621, 0.8697029153003996, 0.6323058872339405, 0.829842783313489, 0.7465889838430303, 0.7875078370745876, 0.7606959923449286, 0.8821200691712012]\n",
      "Test-Dataset AUC: 0.8084\n",
      "CPU times: user 13.7 s, sys: 5.5 s, total: 19.2 s\n",
      "Wall time: 13.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU AUC: 0.8156\n",
    "# 4 GPU AUC: 0.8079\n",
    "#test_loss = valid_epoch(chexnet_sym_test, test_loader, criterion, 'testing')\n",
    "test_loss = valid_epoch(chexnet_sym, test_loader, criterion, 'testing')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Synthetic Data (Pure Training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87296\n"
     ]
    }
   ],
   "source": [
    "# Test on fake-data -> no IO lag\n",
    "batch_in_epoch = len(train_dataset.labels)//BATCHSIZE\n",
    "tot_num = batch_in_epoch * BATCHSIZE\n",
    "print(tot_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "fake_X = torch.tensor(np.random.rand(tot_num, 3, 224, 224).astype(np.float32))\n",
    "fake_y = torch.tensor(np.random.rand(tot_num, CLASSES).astype(np.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training epoch\n",
      "Training loss: 0.7160\n",
      "Training epoch\n",
      "Training loss: 0.6954\n",
      "Training epoch\n",
      "Training loss: 0.6955\n",
      "Training epoch\n",
      "Training loss: 0.6955\n",
      "Training epoch\n",
      "Training loss: 0.6955\n",
      "CPU times: user 12min 10s, sys: 1min 31s, total: 13min 42s\n",
      "Wall time: 8min 36s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Synthetic data: 25min 14s\n",
    "# 4 GPU - Synthetic data: 8min 27s\n",
    "for j in range(EPOCHS):\n",
    "    train_epoch(chexnet_sym, \n",
    "                yield_mb(fake_X, fake_y, BATCHSIZE, shuffle=False),\n",
    "                optimizer, \n",
    "                criterion)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
